# =============================================================================
# Statistical Analysis Code for:
# "Chronic adaptive versus conventional DBS response patterns in Parkinson's
#  disease: A pilot randomized crossover trial"
#
# Visualization: Treatment effect modifier heatmap
# Author: Jun Tanimura, MD, MSc.
# Created: 2025-08-08
# Last Updated: 2025-09-07
# Version: v3.3
# =============================================================================

# Load required packages
library(tidyverse)
library(ggplot2)
library(viridis)
library(scales)

# Load the unified treatment-effect-modifier results and drop every row whose
# Modifier starts with "sex" (Sex is not shown in the heatmap). Rows with a
# missing Modifier are also dropped, because filter() discards NA conditions.
modifier_data <- read_csv("results_treatment_effect_modifiers_unified_fixed.csv")
modifier_data <- filter(
  modifier_data,
  str_detect(Modifier, "^sex", negate = TRUE)
)

# Add display labels, modifier categories, effect-magnitude symbols and
# truncated copies of the effect estimates to every modifier x outcome row.
# Work happens inside local() so the helper lookup table does not leak into
# the global environment.
modifier_data_clean <- local({
  # Outcome column -> y-axis label, phrased in the "better outcome" direction.
  # Outcomes not listed here fall back to "Better <Evaluation>".
  evaluation_labels <- c(
    Mean_corr_ON_time       = "Longer ON Time",
    Mean_corr_OFF_time      = "Shorter OFF Time",
    Mean_corr_Dys_sev_time  = "Shorter Troublesome Dyskinesia",
    Mean_corr_Dys_weak_time = "Shorter Non-troublesome Dyskinesia",
    ADL                     = "Better ADL",
    UPDRS_part_1            = "Better UPDRS Part 1",
    UPDRS_part_2            = "Better UPDRS Part 2",
    UPDRS_part_3            = "Better UPDRS Part 3",
    UPDRS_part_4            = "Better UPDRS Part 4",
    UPDRS_total             = "Better UPDRS Total",
    MMSE                    = "Better MMSE",
    PDQ_39                  = "Better PDQ_39"
  )

  mutate(
    modifier_data,

    # Baseline modifier -> x-axis label, phrased as the "worse baseline"
    # direction. Matching is by prefix; the first matching prefix wins.
    Modifier_Clean = case_when(
      str_starts(Modifier, "age") ~ "Older Age",
      str_starts(Modifier, "PD_duration") ~ "Longer PD Duration",
      str_starts(Modifier, "LEDD") ~ "Higher LEDD",
      str_starts(Modifier, "P_ADL") ~ "Worse ADL",
      str_starts(Modifier, "P_HY") ~ "Higher Hoehn & Yahr Stage",
      str_starts(Modifier, "P_UPDRS_part_1") ~ "Worse UPDRS Part 1",
      str_starts(Modifier, "P_UPDRS_part_2") ~ "Worse UPDRS Part 2",
      str_starts(Modifier, "P_UPDRS_part_3") ~ "Worse UPDRS Part 3",
      str_starts(Modifier, "P_UPDRS_part_4") ~ "Worse UPDRS Part 4",
      str_starts(Modifier, "P_UPDRS_total") ~ "Worse UPDRS Total",
      str_starts(Modifier, "P_Mean_corr_ON_time") ~ "Shorter ON Time",
      str_starts(Modifier, "P_Mean_corr_OFF_time") ~ "Longer OFF Time",
      str_starts(Modifier, "P_Mean_corr_Dys_weak_time") ~ "Longer Mild Dyskinesia Time",
      str_starts(Modifier, "P_Mean_corr_Dys_sev_time") ~ "Longer Severe Dyskinesia Time",
      str_starts(Modifier, "P_MMSE") ~ "Lower MMSE",
      str_starts(Modifier, "P_PDQ_39") ~ "Higher PDQ39",
      TRUE ~ paste0("Worse ", str_remove(Modifier, "_standardized"))
    ),

    # Table lookup; anything unlisted (including NA) gets the generic label.
    Evaluation_Clean = coalesce(
      unname(evaluation_labels[as.character(Evaluation)]),
      paste0("Better ", Evaluation)
    ),

    # Grouping used by the category-wise summary. Order matters: the
    # unanchored "time" test only sees modifiers not claimed earlier.
    Modifier_Category = case_when(
      str_starts(Modifier, "S_") ~ "Beta-band Metrics",
      str_detect(Modifier, "^age|^PD_duration|^sex") ~ "Demographics",
      str_starts(Modifier, "P_UPDRS") ~ "Baseline Motor Symptoms",
      str_detect(Modifier, "^P_HY|^P_ADL|^P_MMSE") ~ "Baseline Functional Status",
      str_detect(Modifier, "time") ~ "Baseline Motor Fluctuations",
      str_starts(Modifier, "P_PDQ_39") ~ "Baseline Quality of Life",
      TRUE ~ "Other"
    ),

    # Tile annotation driven purely by |partial r| (not by CI or p-value):
    # "**" above 0.8, "*" above 0.5, blank otherwise (NA -> blank).
    Effect_Symbol = case_when(
      abs(Partial_r_Unified) > 0.8 ~ "**",
      abs(Partial_r_Unified) > 0.5 ~ "*",
      TRUE ~ ""
    ),

    # Untouched copy of the unified effect size (kept for the exported CSV).
    CohenD_for_Color = Effect_Size_Unified,

    # Effect size clamped to [-2, 2].
    CohenD_Truncated = pmin(pmax(Effect_Size_Unified, -2.0), 2.0),

    # Sign-based reading of the interaction term.
    Clinical_Interpretation = case_when(
      Interaction_Estimate_Unified > 0 ~ "Severe Cases → aDBS Superiority",
      Interaction_Estimate_Unified < 0 ~ "Mild Cases → aDBS Superiority",
      TRUE ~ "No Differential Effect"
    ),

    # Interaction estimate clamped to [-10, 10].
    Interaction_Estimate_Truncated = pmin(pmax(Interaction_Estimate_Unified, -10), 10)
  )
})

# Build the treatment-effect-modifier heatmap.
#
# Columns are baseline modifiers (Modifier_Clean), rows are outcomes
# (Evaluation_Clean). Tiles are filled by Partial_r_Unified on a diverging
# blue-white-red scale fixed to [-1, 1] and annotated with Effect_Symbol.
# Thin vertical separators are drawn between modifier domains.
#
# @param data Data frame with columns Modifier_Clean, Evaluation_Clean,
#   Partial_r_Unified and Effect_Symbol.
# @return A ggplot object (not printed).
create_unified_heatmap <- function(data) {

  # Modifier columns grouped by clinical domain. The flattened list is the
  # left-to-right axis order; a separator is drawn between adjacent groups.
  modifier_groups <- list(
    clinical = c(
      "Older Age", "Longer PD Duration", "Higher LEDD", "Worse ADL",
      "Higher Hoehn & Yahr Stage"
    ),
    updrs = c(
      "Worse UPDRS Part 1", "Worse UPDRS Part 2", "Worse UPDRS Part 3",
      "Worse UPDRS Part 4", "Worse UPDRS Total"
    ),
    fluctuations = c(
      "Shorter ON Time", "Longer OFF Time",
      "Longer Mild Dyskinesia Time", "Longer Severe Dyskinesia Time"
    ),
    cognition = "Lower MMSE",
    quality_of_life = "Higher PDQ39"
  )
  modifier_order <- unlist(modifier_groups, use.names = FALSE)

  # Outcome rows by clinical grouping (listed bottom-to-top on the y-axis)
  evaluation_order <- c(
    "Better PDQ_39", "Better MMSE",
    "Better UPDRS Total", "Better UPDRS Part 4", "Better UPDRS Part 3",
    "Better UPDRS Part 2", "Better UPDRS Part 1", "Better ADL",
    "Shorter Non-troublesome Dyskinesia",
    "Shorter Troublesome Dyskinesia", "Shorter OFF Time", "Longer ON Time"
  )

  # Labels outside these orders become NA factor values.
  plot_data <- data %>%
    mutate(
      Modifier_Clean = factor(Modifier_Clean, levels = modifier_order),
      Evaluation_Clean = factor(Evaluation_Clean, levels = evaluation_order)
    )

  # The discrete x scale drops levels that are absent from the data, so tile
  # positions are indices among the PRESENT modifiers. Separator positions
  # are therefore derived from those, not hard-coded (hard-coded positions
  # shift as soon as a modifier such as LEDD is added or missing).
  group_index <- rep(seq_along(modifier_groups), lengths(modifier_groups))
  is_present <- modifier_order %in% as.character(plot_data$Modifier_Clean)
  present_groups <- group_index[is_present]
  separator_x <- which(diff(present_groups) != 0) + 0.5

  p_heat <- ggplot(plot_data, aes(x = Modifier_Clean, y = Evaluation_Clean)) +

    # Heatmap tiles coloured by partial r
    geom_tile(aes(fill = Partial_r_Unified), color = "white", linewidth = 0.3) +

    # Effect-magnitude markers (|r| thresholds from Effect_Symbol)
    geom_text(aes(label = Effect_Symbol), size = 2, color = "black", fontface = "bold") +

    # Diverging scale on the full partial-r range
    scale_fill_gradient2(
      low = "#2166AC",
      mid = "white",
      high = "#B2182B",
      midpoint = 0,
      name = "Effect Size\n(partial r)",
      limits = c(-1, 1)
    ) +

    labs(
      title = "Treatment Effect Modifiers: Worse Baseline → Better Outcome with aDBS",
      subtitle = "Effect sizes (partial r) with effect-magnitude markers",
      x = "Baseline Predictors",
      y = "Clinical Outcomes",
      caption = "* |r| > 0.5, ** |r| > 0.8"
    ) +

    theme_minimal(base_size = 10) +
    theme(
      plot.title = element_text(size = 14, face = "bold"),
      plot.subtitle = element_text(size = 11, color = "gray30"),
      axis.text.x = element_text(angle = 45, hjust = 1, size = 9),
      axis.text.y = element_text(size = 8),
      axis.title = element_text(size = 11, face = "bold"),
      legend.position = "right",
      plot.caption = element_text(size = 8, color = "gray50", hjust = 0),
      panel.grid = element_blank(),
      plot.margin = margin(10, 10, 10, 10)
    )

  # Domain separators (skipped when fewer than two domains are present)
  if (length(separator_x) > 0) {
    p_heat <- p_heat +
      geom_vline(
        xintercept = separator_x,
        color = "gray30", linewidth = 0.8, alpha = 0.7
      )
  }

  p_heat
}

# Sanity-print the number of rows, then build the heatmap and write it to PDF
cat("=== Data Check Before Visualization ===\n")
cat("Total modifier combinations: ", nrow(modifier_data_clean), "\n", sep = "")

heatmap_plot <- create_unified_heatmap(data = modifier_data_clean)

# 6 x 4 figure on a white background
ggsave(
  filename = "figure_treatment_modifier_heatmap_unified_fixed.pdf",
  plot = heatmap_plot,
  width = 6,
  height = 4,
  bg = "white"
)

# One-row overall summary across every modifier x outcome combination.
# Counts ignore NA values; proportions are taken over ALL rows (n()),
# including rows whose underlying value was NA.
summary_stats <- summarise(
  modifier_data_clean,
  Total_Combinations = n(),
  Robust_Effects = sum(CI_Excludes_Zero, na.rm = TRUE),
  Large_Effects = sum(abs(Partial_r_Unified) > 0.5, na.rm = TRUE),
  Positive_Effects = sum(Interaction_Estimate_Unified > 0, na.rm = TRUE),
  Negative_Effects = sum(Interaction_Estimate_Unified < 0, na.rm = TRUE),
  Mean_Abs_Effect = mean(abs(Interaction_Estimate_Unified), na.rm = TRUE),
  Mean_AIC_Improvement = mean(AIC_Improvement, na.rm = TRUE),
  # Proportions reuse the counts defined above in the same call
  Robust_Proportion = Robust_Effects / Total_Combinations,
  Large_Proportion = Large_Effects / Total_Combinations,
  Positive_Proportion = Positive_Effects / Total_Combinations,
  .groups = "drop"
)

# Category-wise summary: one row per Modifier_Category, ranked by the mean
# absolute partial r (the standardized metric shown in the heatmap).
# Counts ignore NA values; proportions are over all rows in the category.
category_stats <- modifier_data_clean %>%
  group_by(Modifier_Category) %>%
  summarise(
    N_Combinations = n(),
    Robust_Effects = sum(CI_Excludes_Zero, na.rm = TRUE),
    Large_Effects = sum(abs(Partial_r_Unified) > 0.5, na.rm = TRUE),
    Mean_Abs_Effect = mean(abs(Interaction_Estimate_Unified), na.rm = TRUE),
    Mean_Abs_r = mean(abs(Partial_r_Unified), na.rm = TRUE),
    Positive_Effects = sum(Interaction_Estimate_Unified > 0, na.rm = TRUE),
    Mean_AIC_Improvement = mean(AIC_Improvement, na.rm = TRUE),
    .groups = "drop"
  ) %>%
  mutate(
    Robust_Proportion = Robust_Effects / N_Combinations,
    Positive_Proportion = Positive_Effects / N_Combinations
  ) %>%
  # The former sort key `Mean_Effect_Size` is never created above, so
  # arrange() errored and the script stopped here. Mean_Abs_r is the
  # scale-free measure, comparable across outcomes with different units.
  arrange(desc(Mean_Abs_r))

# The 20 rows with the largest |Interaction_Estimate_Unified| (raw interaction
# scale, not standardized), restricted to the reporting columns. Rows with a
# missing estimate sort last.
top_effects_summary <- modifier_data_clean %>%
  select(
    Modifier_Clean,
    Evaluation_Clean,
    Interaction_Estimate_Unified,
    Clinical_Interpretation,
    CI_Lower_Unified,
    CI_Upper_Unified,
    Effect_Size_Unified,
    CI_Excludes_Zero,
    AIC_Improvement
  ) %>%
  arrange(-abs(Interaction_Estimate_Unified)) %>%
  head(20)

# Export the labelled data and every summary table; each list name is the
# output path for the table stored under it (written in this order).
iwalk(
  list(
    "results_treatment_effect_modifiers_unified_labeled_fixed.csv" = modifier_data_clean,
    "results_treatment_modifier_summary_stats_unified_fixed.csv" = summary_stats,
    "results_treatment_modifier_category_stats_unified_fixed.csv" = category_stats,
    "results_top_treatment_effects_unified_fixed.csv" = top_effects_summary
  ),
  function(tbl, path) write_csv(tbl, path)
)

# Report what was written and preview the summary tables.
# The listed file names must match the paths passed to ggsave()/write_csv()
# above; the previous messages listed differently named files and claimed
# Cohen's D colouring although the heatmap fill is partial r.
cat("\n=== FIXED PARTIAL r HEATMAP VISUALIZATION COMPLETE ===\n")
cat("Files created:\n")
cat("1. figure_treatment_modifier_heatmap_unified_fixed.pdf - Main heatmap with partial r coloring\n")
cat("2. results_treatment_effect_modifiers_unified_labeled_fixed.csv - Complete data with explicit labels\n")
cat("3. results_treatment_modifier_summary_stats_unified_fixed.csv - Overall summary statistics\n")
cat("4. results_treatment_modifier_category_stats_unified_fixed.csv - Category-wise statistics\n")
cat("5. results_top_treatment_effects_unified_fixed.csv - Top 20 effects by |interaction estimate|\n")

cat("\n=== UNIFIED SIGN CONVENTION SUMMARY ===\n")
print(summary_stats)

cat("\n=== CATEGORY-WISE SUMMARY ===\n")
print(category_stats)

cat("\n=== TOP EFFECTS PREVIEW ===\n")
print(top_effects_summary %>% slice_head(n = 5))

# Display the heatmap on the active graphics device
print(heatmap_plot)

cat("\nFixed unified heatmap visualization created successfully!\n")
cat("Red areas: Higher disease burden on baseline conditions favor aDBS\n")
cat("Blue areas: Milder disease burden on baseline conditions favor aDBS\n")